Binary Tree Pruning [DFS]
Time: O(N); Space: O(H), where H is the tree height (recursion stack); Difficulty: medium
We are given the root of a binary tree in which every node's value is either 0 or 1. Return the same tree with every subtree (of the given tree) that does not contain a 1 removed. (Recall that the subtree of a node X is X itself plus every descendant of X.)
Example 1:
Input: root = {TreeNode} [1,None,0,0,1]
Output: {TreeNode} [1,None,0,None,1]
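Explanation:
Only the 0 at root.right.left roots a subtree that contains no 1, so it is the only node removed.
(In the original problem's diagram, such nodes — those whose subtree contains no 1 — are highlighted in red, and a second diagram shows the pruned tree.)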
Example 2:
Input: root = {TreeNode} [1,0,1,0,0,0,1]
Output: {TreeNode} [1,None,1,None,1]
Example 3:
Input: root = {TreeNode} [1,1,0,1,1,0,1,0]
Output: {TreeNode} [1,1,0,1,1,None,1]
Notes:
The binary tree will have at most 100 nodes.
The value of each node will only be 0 or 1.
[1]:
class TreeNode:
    """Binary-tree node: a value plus optional left and right children."""

    def __init__(self, x, left=None, right=None):
        # left/right default to None, so TreeNode(x) still builds a leaf exactly
        # as before; passing them lets a whole tree be built in one expression.
        self.val = x
        self.left = left
        self.right = right
Auxiliary Tools
[2]:
from graphviz import Graph
class TreeTasks(object):
    """Notebook helpers for inspecting binary trees."""

    def visualize_tree(self, tree):
        """
        Render a binary tree with graphviz and display it inline.

        Left children are labelled ".<val>" and right children "<val>." so the
        two sides can be told apart in the undirected drawing; the root is
        labelled with its plain value.

        :param tree: root node (object with val/left/right attributes), or None
        :return: the graphviz Graph that was displayed (empty when tree is None)
        """
        # Create Graph (not Digraph) object
        dot = Graph()

        def node_id(node):
            # id() is unique among live objects and, unlike str(node), does not
            # depend on whether the node class defines __repr__/__str__.
            return str(id(node))

        def add_children(parent):
            # Declare each child exactly once. Re-declaring a node in DOT
            # overrides its earlier label, which is what previously erased the
            # "."-markers when the recursion re-added the child plainly.
            parent_name = node_id(parent)
            if parent.left:
                child = parent.left
                dot.node(name=node_id(child), label="." + str(child.val))
                dot.edge(parent_name, node_id(child))
                add_children(child)
            if parent.right:
                child = parent.right
                dot.node(name=node_id(child), label=str(child.val) + ".")
                dot.edge(parent_name, node_id(child))
                add_children(child)

        # pruneTree returns None for an all-zero tree; render an empty graph
        # instead of failing on None.val.
        if tree:
            dot.node(name=node_id(tree), label=str(tree.val))
            add_children(tree)

        # `display` is the IPython/Jupyter rich-display builtin (notebook only).
        display(dot)
        return dot
[3]:
class Solution1(object):
    def pruneTree(self, root):
        """
        Remove, in place, every subtree that contains no 1.

        Post-order DFS: both children are pruned first, so when a node is
        examined each surviving child is already known to contain a 1. The
        node itself survives if it still has a child or is not a 0.

        :type root: TreeNode
        :rtype: TreeNode -- the same root object, or None if nothing survives
        """
        if not root:
            return None

        kept_left = self.pruneTree(root.left)
        root.left = kept_left
        kept_right = self.pruneTree(root.right)
        root.right = kept_right

        if kept_left or kept_right or root.val != 0:
            return root
        return None
[4]:
# Example 1: [1,None,0,0,1] -> [1,None,0,None,1]
s = Solution1()
root = TreeNode(1)
root.right = TreeNode(0)
root.right.left = TreeNode(0)
root.right.right = TreeNode(1)
tree = s.pruneTree(root)
t = TreeTasks()
dot = t.visualize_tree(tree)
# Only the all-zero leaf root.right.left is removed.
assert tree.val == 1
assert tree.left is None
assert tree.right.val == 0
assert tree.right.left is None
assert tree.right.right.val == 1
[5]:
# Example 2: [1,0,1,0,0,0,1] -> [1,None,1,None,1]
s = Solution1()  # local instance so this cell runs independently
root = TreeNode(1)
root.left = TreeNode(0)
root.right = TreeNode(1)
root.left.left = TreeNode(0)
root.left.right = TreeNode(0)
root.right.left = TreeNode(0)
root.right.right = TreeNode(1)
tree = s.pruneTree(root)
t = TreeTasks()
dot = t.visualize_tree(tree)
# The whole left subtree and root.right.left contain no 1 and are removed.
assert tree.val == 1
assert tree.left is None
assert tree.right.val == 1
assert tree.right.left is None
assert tree.right.right.val == 1
[6]:
# Example 3: [1,1,0,1,1,0,1,0] -> [1,1,0,1,1,None,1]
s = Solution1()  # local instance so this cell runs independently
root = TreeNode(1)
root.left = TreeNode(1)
root.right = TreeNode(0)
root.left.left = TreeNode(1)
root.left.right = TreeNode(1)
root.right.left = TreeNode(0)
root.right.right = TreeNode(1)
root.left.left.left = TreeNode(0)
tree = s.pruneTree(root)
t = TreeTasks()
dot = t.visualize_tree(tree)
# The zero leaves root.left.left.left and root.right.left are removed;
# root.right (a 0) survives because its right child is a 1.
assert tree.val == 1
assert tree.left.val == 1
assert tree.right.val == 0
assert tree.left.left.val == 1
assert tree.left.left.left is None
assert tree.left.right.val == 1
assert tree.right.left is None
assert tree.right.right.val == 1